/*
 * Copyright 2019 ACINQ SAS
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64}
import fr.acinq.eclair.crypto.Hmac256
import fr.acinq.eclair.wire.protocol.FailureMessageCodecs._
import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, TimestampSecond, TimestampSecondLong, UInt64, randomBytes32}
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits._

/**
 * Created by PM on 31/05/2016.
 */

class FailureMessageCodecsSpec extends AnyFunSuite {

  test("encode/decode all failure messages") {
    val channelUpdate = ChannelUpdate(
      signature = ByteVector64(hex"3eedbba335d4fb4c772c95bf6a0de91b86950fabbb3a9f203beef44f30547f89eaf5dbf434520b4dcab5a0641589aa5483cdfd24902295310383f9ab8835e620"),
      chainHash = Block.RegtestGenesisBlock.hash,
      shortChannelId = ShortChannelId(12345),
      timestamp = TimestampSecond(1234567L),
      cltvExpiryDelta = CltvExpiryDelta(100),
      messageFlags = ChannelUpdate.MessageFlags(dontForward = false),
      channelFlags = ChannelUpdate.ChannelFlags(isEnabled = true, isNode1 = false),
      htlcMinimumMsat = 1000 msat,
      feeBaseMsat = 12 msat,
      feeProportionalMillionths = 76,
      htlcMaximumMsat = 150_000_000 msat)
    val testCases = Map[FailureMessage, ByteVector](
      InvalidRealm() -> hex"4001",
      TemporaryNodeFailure() -> hex"2002",
      TemporaryNodeFailure(TlvStream(Set.empty[FailureMessageTlv], Set(GenericTlv(UInt64(561), hex"deadbeef"), GenericTlv(UInt64(1105), hex"0102030405")))) -> hex"2002 fd023104deadbeef fd0451050102030405",
      PermanentNodeFailure() -> hex"6002",
      RequiredNodeFeatureMissing() -> hex"6003",
      InvalidOnionVersion(ByteVector32(hex"d8db0e777047d814f569c8243073be42e56a411ebfe82fd877ba958fb068ae3e")) -> hex"c004 d8db0e777047d814f569c8243073be42e56a411ebfe82fd877ba958fb068ae3e",
      InvalidOnionHmac(ByteVector32(hex"1c9836f65130ee10a13da0db2d2acef8bc799978d351700f1a09aefc3ab221f7")) -> hex"c005 1c9836f65130ee10a13da0db2d2acef8bc799978d351700f1a09aefc3ab221f7",
      InvalidOnionKey(ByteVector32(hex"7568cf300a7b7458693904d50e67dc0c29a5116600d93e9979c3fa91e2b85395")) -> hex"c006 7568cf300a7b7458693904d50e67dc0c29a5116600d93e9979c3fa91e2b85395",
      InvalidOnionBlinding(ByteVector32(hex"d71cd923bb201254bc07dadde7e795b8c6b0b849325ee3c603e1bba2e5d2c100")) -> hex"c018 d71cd923bb201254bc07dadde7e795b8c6b0b849325ee3c603e1bba2e5d2c100",
      TemporaryChannelFailure(Some(channelUpdate)) -> hex"1007 008a 0102 3eedbba335d4fb4c772c95bf6a0de91b86950fabbb3a9f203beef44f30547f89eaf5dbf434520b4dcab5a0641589aa5483cdfd24902295310383f9ab8835e62006226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00000000000030390012d6870101006400000000000003e80000000c0000004c0000000008f0d180",
      TemporaryChannelFailure(None) -> hex"1007 0000",
      PermanentChannelFailure() -> hex"4008",
      RequiredChannelFeatureMissing() -> hex"4009",
      UnknownNextPeer() -> hex"400a",
      AmountBelowMinimum(123456 msat, Some(channelUpdate)) -> hex"100b 000000000001e240 008a 0102 3eedbba335d4fb4c772c95bf6a0de91b86950fabbb3a9f203beef44f30547f89eaf5dbf434520b4dcab5a0641589aa5483cdfd24902295310383f9ab8835e62006226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00000000000030390012d6870101006400000000000003e80000000c0000004c0000000008f0d180",
      FeeInsufficient(546463 msat, Some(channelUpdate)) -> hex"100c 000000000008569f 008a 0102 3eedbba335d4fb4c772c95bf6a0de91b86950fabbb3a9f203beef44f30547f89eaf5dbf434520b4dcab5a0641589aa5483cdfd24902295310383f9ab8835e62006226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00000000000030390012d6870101006400000000000003e80000000c0000004c0000000008f0d180",
      ChannelDisabled(ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = true, isNode1 = false), Some(channelUpdate)) -> hex"1014 01 01 008a 0102 3eedbba335d4fb4c772c95bf6a0de91b86950fabbb3a9f203beef44f30547f89eaf5dbf434520b4dcab5a0641589aa5483cdfd24902295310383f9ab8835e62006226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00000000000030390012d6870101006400000000000003e80000000c0000004c0000000008f0d180",
      IncorrectCltvExpiry(CltvExpiry(1211), Some(channelUpdate)) -> hex"100d 000004bb 008a 0102 3eedbba335d4fb4c772c95bf6a0de91b86950fabbb3a9f203beef44f30547f89eaf5dbf434520b4dcab5a0641589aa5483cdfd24902295310383f9ab8835e62006226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00000000000030390012d6870101006400000000000003e80000000c0000004c0000000008f0d180",
      IncorrectOrUnknownPaymentDetails(123456 msat, BlockHeight(1105)) -> hex"400f 000000000001e240 00000451",
      IncorrectOrUnknownPaymentDetails(100 msat, BlockHeight(800_000), TlvStream(Set.empty[FailureMessageTlv], Set(GenericTlv(UInt64(34001), ByteVector.fill(300)(128))))) -> hex"400f 0000000000000064 000c3500 fd84d1 fd012c 808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080808080",
      ExpiryTooSoon(Some(channelUpdate)) -> hex"100e 008a 0102 3eedbba335d4fb4c772c95bf6a0de91b86950fabbb3a9f203beef44f30547f89eaf5dbf434520b4dcab5a0641589aa5483cdfd24902295310383f9ab8835e62006226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00000000000030390012d6870101006400000000000003e80000000c0000004c0000000008f0d180",
      FinalIncorrectCltvExpiry(CltvExpiry(1234)) -> hex"0012 000004d2",
      FinalIncorrectHtlcAmount(25_000_000 msat) -> hex"0013 00000000017d7840",
      ExpiryTooFar() -> hex"0015",
      InvalidOnionPayload(UInt64(561), 1105) -> hex"4016 fd0231 0451",
      PaymentTimeout() -> hex"0017",
      TrampolineFeeInsufficient() -> hex"2033",
      TrampolineExpiryTooSoon() -> hex"2034",
    )
    testCases.foreach { case (msg, bin) =>
      val encoded = failureMessageCodec.encode(msg).require
      assert(encoded.bytes == bin)
      val decoded = failureMessageCodec.decode(encoded).require
      assert(msg == decoded.value)
    }
  }

  test("decode unknown failure messages") {
    val testCases = Seq(
      // Deprecated incorrect_payment_amount.
      (false, true, hex"4010"),
      // Deprecated final_expiry_too_soon.
      (false, true, hex"4011"),
      // Unknown failure messages.
      (false, false, hex"00ff 42"),
      (true, false, hex"20ff 42"),
      (true, true, hex"60ff 42")
    )

    for ((node, perm, bin) <- testCases) {
      val decoded = failureMessageCodec.decode(bin.bits).require.value
      assert(decoded.isInstanceOf[FailureMessage])
      assert(decoded.isInstanceOf[UnknownFailureMessage])
      assert(decoded.isInstanceOf[Node] == node)
      assert(decoded.isInstanceOf[Perm] == perm)
    }
  }

  test("bad onion failure code") {
    val msgs = Map(
      (BADONION | PERM | 4) -> InvalidOnionVersion(randomBytes32()),
      (BADONION | PERM | 5) -> InvalidOnionHmac(randomBytes32()),
      (BADONION | PERM | 6) -> InvalidOnionKey(randomBytes32())
    )

    for ((code, message) <- msgs) {
      assert(message.code == code)
    }
  }

  test("encode/decode failure onion") {
    val codec = failureOnionCodec(Hmac256(ByteVector32.Zeroes))
    val testCases = Map(
      InvalidOnionKey(ByteVector32(hex"2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a")) -> hex"41a824e2d630111669fa3e52b600a518f369691909b4e89205dc624ee17ed2c1 0022 c006 2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a 00de 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      IncorrectOrUnknownPaymentDetails(42 msat, BlockHeight(1105)) -> hex"5eb766da1b2f45b4182e064dacd8da9eca2c9a33f0dce363ff308e9bdb3ee4e3 000e 400f 000000000000002a 00000451 00f2 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
    )

    for ((expected, bin) <- testCases) {
      val decoded = codec.decode(bin.toBitVector).require.value
      assert(decoded == expected)

      val encoded = codec.encode(expected).require.toByteVector
      assert(encoded == bin)
    }
  }

  test("decode backwards-compatible IncorrectOrUnknownPaymentDetails") {
    val codec = failureOnionCodec(Hmac256(ByteVector32.Zeroes))
    val testCases = Map(
      // Without any data.
      IncorrectOrUnknownPaymentDetails(0 msat, BlockHeight(0)) -> hex"0d83b55dd5a6086e4033c3659125ed1ff436964ce0e67ed5a03bddb16a9a1041 0002 400f 00fe 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      // With an amount but no height.
      IncorrectOrUnknownPaymentDetails(42 msat, BlockHeight(0)) -> hex"ba6e122b2941619e2106e8437bf525356ffc8439ac3b2245f68546e298a08cc6 000a 400f 000000000000002a 00f6 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      // With amount and height.
      IncorrectOrUnknownPaymentDetails(42 msat, BlockHeight(1105)) -> hex"5eb766da1b2f45b4182e064dacd8da9eca2c9a33f0dce363ff308e9bdb3ee4e3 000e 400f 000000000000002a 00000451 00f2 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
    )

    for ((expected, bin) <- testCases) {
      assert(codec.decode(bin.bits).require.value == expected)
    }
  }

  test("decode failure onion packet with arbitrary length") {
    val codec = failureOnionCodec(Hmac256(ByteVector32.Zeroes))
    val testCases = Seq(
      InvalidRealm() -> hex"7bfb2aa46218240684f623322ae48af431d06986c82e210bb0cee83c7ddb2ba8 0002 4001 0002 0000",
      IncorrectOrUnknownPaymentDetails(1105 msat, BlockHeight(1729)) -> hex"c508151d550a6a7fb121542b7c383fd7f18381832499c419de436e131c1f3a76 000e 400f 0000000000000451 000006c1 0004 deadbeef",
      InvalidRealm() -> hex"6f9e2c0e44b3692dac37523c6ff054cc9b26ecab1a78ed6906a46848bffc2bd5 0002 4001 00ff 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      IncorrectOrUnknownPaymentDetails(1105 msat, BlockHeight(1729)) -> hex"bb2873dad5447927774cb7de99f43c0b5f54f6e298b5be4d7ca88677b8f0817d 000e 400f 0000000000000451 000006c1 00ff 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
    )

    for ((expected, bin) <- testCases) {
      assert(codec.decode(bin.bits).require.value == expected)
    }
  }

  test("decode invalid failure onion packet") {
    val codec = failureOnionCodec(Hmac256(ByteVector32.Zeroes))
    val testCases = Seq(
      // Invalid failure message.
      hex"fd2f3eb163dacfa7fe2ec1a7dc73c33438e7ca97c561475cf0dc96dc15a75039 0020 c005 2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a 00e0 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      // Invalid mac.
      hex"0000000000000000000000000000000000000000000000000000000000000000 0022 c006 2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a 00de 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      // Padding length doesn't match actual padding.
      hex"8c92256e45bbe765130d952e6c043cf594ab25224701f5477fce0e50ee88fa21 0002 4001 0002 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      // Padding length doesn't match actual padding.
      hex"3898307b7c01781628ff6f854a4a78524541e4afde9b44046bdb84093f082d9d 0002 4001 00ff 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
    )

    for (testCase <- testCases) {
      assert(codec.decode(testCase.toBitVector).isFailure)
    }
  }

  test("support encoding of channel_update with/without type in failure messages") {
    val tmpChannelFailureWithoutType = hex"10070088cc3e80149073ed487c76e48e9622bf980f78267b8a34a3f61921f2d8fce6063b08e74f34a073a13f2097337e4915bb4c001f3b5c4d81e9524ed575e1f45782196fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d619000000000008260500041300005b91b52f0103000e00000000000003e800000001000000010000000008f0d180"
    val tmpChannelFailureWithType = hex"1007008a0102cc3e80149073ed487c76e48e9622bf980f78267b8a34a3f61921f2d8fce6063b08e74f34a073a13f2097337e4915bb4c001f3b5c4d81e9524ed575e1f45782196fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d619000000000008260500041300005b91b52f0103000e00000000000003e800000001000000010000000008f0d180"
    val ref = TemporaryChannelFailure(Some(ChannelUpdate(ByteVector64(hex"cc3e80149073ed487c76e48e9622bf980f78267b8a34a3f61921f2d8fce6063b08e74f34a073a13f2097337e4915bb4c001f3b5c4d81e9524ed575e1f4578219"), Block.LivenetGenesisBlock.hash, ShortChannelId(0x826050004130000L), 1536275759 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = false, isNode1 = false), CltvExpiryDelta(14), 1000 msat, 1 msat, 1, 150_000_000 msat)))

    val u1 = failureMessageCodec.decode(tmpChannelFailureWithoutType.toBitVector).require.value
    assert(u1 == ref)
    val bin = failureMessageCodec.encode(u1).require.bytes
    assert(bin == tmpChannelFailureWithType)
    val u2 = failureMessageCodec.decode(bin.toBitVector).require.value
    assert(u2 == ref)
  }

}
